import os

# Silence TensorFlow's C++ backend logging (2 = hide INFO and WARNING).
# NOTE: this must be set BEFORE tensorflow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings

# Suppress Python warnings (e.g. deprecation noise from TF/Keras).
warnings.filterwarnings('ignore')
import tensorflow as tf

# 40 corresponds to logging.ERROR: silence TF1-compat Python-side logging.
tf.compat.v1.logging.set_verbosity(40)

from tensorflow.keras import layers, models, Sequential, optimizers, losses
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, BatchNormalization, Activation, GlobalAveragePooling2D


# Basic residual block used by ResNet-18 and ResNet-34 (two 3x3 convs + shortcut).
class CellBlock(models.Model):
    """Residual building block: Conv-BN-ReLU -> Conv-BN, plus a shortcut.

    Args:
        filter_num: number of filters for both 3x3 convolutions.
        strides: stride of the first convolution; when != 1 the spatial
            resolution is reduced and a 1x1 projection shortcut is used.
    """

    def __init__(self, filter_num, strides=1):
        super(CellBlock, self).__init__()

        self.conv1 = Conv2D(filter_num, (3, 3), strides=strides, padding='same')
        self.bn1 = BatchNormalization()
        self.relu = Activation('relu')

        self.conv2 = Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = BatchNormalization()

        # Downsampling block: project the input with a strided 1x1 conv so the
        # shortcut matches the main path's shape; otherwise use identity.
        if strides != 1:
            self.residual = Conv2D(filter_num, (1, 1), strides=strides)
        else:
            self.residual = lambda x: x

    def call(self, inputs, training=None):
        """Forward pass. `training` toggles BatchNormalization behavior."""
        x = self.conv1(inputs)
        # Fix: forward the training flag so BN uses batch statistics during
        # training and moving averages at inference (it was dropped before).
        x = self.bn1(x, training=training)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x, training=training)

        r = self.residual(inputs)

        x = layers.add([x, r])
        output = tf.nn.relu(x)

        return output


# Subclassed Model implementing ResNet-18 / ResNet-34.
class ResNet(models.Model):
    """ResNet-18/34 classifier.

    Args:
        layers_dims: number of residual blocks per stage, e.g. [2, 2, 2, 2]
            for ResNet-18 or [3, 4, 6, 3] for ResNet-34.
        nb_classes: number of output classes (softmax head).
    """

    def __init__(self, layers_dims, nb_classes):
        super(ResNet, self).__init__()

        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool (4x downsampling).
        self.stem = Sequential([
            Conv2D(64, (7, 7), strides=(2, 2), padding='same'),
            BatchNormalization(),
            Activation('relu'),
            MaxPooling2D((3, 3), strides=(2, 2), padding='same')
        ])

        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = self.build_cellblock(64, layers_dims[0])
        self.layer2 = self.build_cellblock(128, layers_dims[1], strides=2)
        self.layer3 = self.build_cellblock(256, layers_dims[2], strides=2)
        self.layer4 = self.build_cellblock(512, layers_dims[3], strides=2)

        self.avgpool = GlobalAveragePooling2D()
        self.fc = Dense(nb_classes, activation='softmax')

    def call(self, inputs, training=None):
        """Forward pass; returns class probabilities of shape (batch, nb_classes)."""
        # Fix: propagate the training flag — the stem and every stage contain
        # BatchNormalization layers whose behavior depends on it (it was
        # silently dropped before).
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        x = self.avgpool(x)
        x = self.fc(x)

        return x

    def build_cellblock(self, filter_num, blocks, strides=1):
        """Stack `blocks` CellBlocks; only the first may downsample (strides)."""
        res_blocks = Sequential()
        res_blocks.add(CellBlock(filter_num, strides))  # first block may have stride != 1

        for _ in range(1, blocks):  # remaining blocks keep the resolution
            res_blocks.add(CellBlock(filter_num, strides=1))

        return res_blocks


def build_ResNet(NetName, nb_classes):
    """Factory for ResNet models.

    Args:
        NetName: either 'ResNet18' or 'ResNet34'.
        nb_classes: number of output classes.

    Returns:
        An uncompiled ResNet model instance.

    Raises:
        KeyError: if NetName is not a recognized configuration.
    """
    configs = {
        'ResNet18': [2, 2, 2, 2],
        'ResNet34': [3, 4, 6, 3],
    }
    stage_depths = configs[NetName]
    return ResNet(stage_depths, nb_classes)


if __name__ == '__main__':

    # Enable on-demand GPU memory growth instead of grabbing all memory up front.
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    (trainx, trainy), (testx, testy) = tf.keras.datasets.cifar10.load_data()

    # Normalize pixel values to [0, 1]; CIFAR-10 images are 32x32 RGB.
    trainx, testx = trainx.reshape((-1, 32, 32, 3)) / 255.0, testx.reshape((-1, 32, 32, 3)) / 255.0

    model = build_ResNet('ResNet18', 10)

    model.build(input_shape=(None, 32, 32, 3))

    model.summary()

    # Labels are integer class ids, hence the sparse cross-entropy loss.
    model.compile(loss=losses.sparse_categorical_crossentropy,
                  optimizer=optimizers.Adam(),
                  metrics=['accuracy'])

    import datetime

    # Timestamped log directory so successive runs don't overwrite each other.
    stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    logdir = os.path.join('data', 'autograph', stamp)
    # Fix: removed the unused `writer = tf.summary.create_file_writer(logdir)`
    # — the TensorBoard callback creates and manages its own summary writers.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

    history = model.fit(trainx, trainy, batch_size=512, epochs=10, callbacks=[tensorboard_callback])

    import matplotlib.pyplot as plt

    # Plot the training-accuracy curve.
    plt.plot(history.history['accuracy'])
    plt.legend(['training'], loc='upper left')
    plt.show()

    result = model.evaluate(testx, testy)
    print('test_loss:', result[0])
    print('test_acc:', result[1])
